from __future__ import annotations

import json
import os
import re
import time

import openai
from openai import OpenAI

from utils.util import read_json, read_txt, write_txt, is_json, write_json
from utils.token_count_decorator import token_count_decorator
from planning.src.protocol import Protocol
from planning.src.representation import Representer
from planning.src.compiler import Compiler

class Planner:
    """Generate plans for novel protocols by prompting an OpenAI chat model.

    Prompt templates are read from ``planning/data/prompt/`` when the planner is
    constructed. ``domain`` selects the ``Representer`` whose output (pseudofunctions,
    DSL specifications, ...) is substituted into those templates; ``domain`` and
    ``task`` together locate previously saved results under ``planning_result/``.
    """

    # Upper bound on attempts for a single chat-completion request.
    _MAX_API_ATTEMPTS = 8
    # Upper bound (seconds) on the exponential back-off sleep between attempts.
    _MAX_BACKOFF_SECONDS = 60
    # API errors that retrying cannot fix (malformed request, bad/missing credentials,
    # no access, unknown resource such as a model name); re-raised immediately.
    _NON_RETRYABLE_ERRORS = (
        openai.BadRequestError,
        openai.AuthenticationError,
        openai.PermissionDeniedError,
        openai.NotFoundError,
    )

    def __init__(self, domain, task) -> None:
        self.domain = domain
        self.task = task
        self.representer = Representer(domain=domain)
        # Pseudocode-oriented prompts ("baseline" and "internal"/"atomic").
        self.atomic_prompt = read_txt("planning/data/prompt/atomic_representation.txt")
        self.original_prompt_1 = read_txt("planning/data/prompt/original_representation_stage1.txt")
        self.original_prompt_2 = read_txt("planning/data/prompt/original_representation_stage2.txt")
        self.pseudocode_to_json_prompt = read_txt("planning/data/prompt/pseudocode_to_json.txt")
        # Single-DSL prompts.
        self.dsl_prompt = read_txt("planning/data/prompt/dsl_representation_2.txt")
        self.dsl_feedback_prompt = read_txt("planning/data/prompt/dsl_feedback.txt")
        self.dsl_refine_prompt = read_txt("planning/data/prompt/dsl_refine_3.txt")
        # Multi-DSL (operation + production) prompts.
        self.multi_dsl_prompt = read_txt("planning/data/prompt/multi_dsl_representation.txt")
        self.multi_dsl_feedback_prompt = read_txt("planning/data/prompt/multi_dsl_feedback.txt")
        self.multi_dsl_refine_prompt = read_txt("planning/data/prompt/multi_dsl_refine_3.txt")

    def plan(self, protocol: Protocol, mode="atomic", method="baseline") -> Protocol:
        '''
        Generate plan for a novel protocol, with different method and representations.

        Args:
            method (str): The strategy for protocol generation. Possible values are:
                - 'baseline': Pure LLM protocol generation, output plan in pseudocode.
                  Supported modes: 'atomic', 'flatten'.
                - 'internal': LLM program synthesis, output plan in DSL.
                  Supported modes: 'atomic', 'dsl', 'multi-dsl'.
                - 'external': LLM program refinement, output plan in DSL, verified by the
                  DSL compiler. Supported modes: 'dsl', 'multi-dsl'.
                - 'heuristic': Classic planning task, output plan in DSL (not implemented).
            mode (str): The mode of representations for planning. Possible values are:
                - 'flatten': Full procedure of similar protocols.
                - 'atomic': Pseudofunctions of similar protocols.
                - 'dsl': DSL specifications of relevant operations.
                - 'multi-dsl': Operation DSL specifications of relevant operations & Production DSL specifications of relevant flowunits

        Returns:
            A new Protocol with the input's id/title/description, the generated
            ``pseudocode`` text (the JSON plan text itself for DSL modes) and the
            parsed JSON ``program``.

        Raises:
            NotImplementedError: if ``method`` is 'heuristic'.
            ValueError: for any other unsupported ``method``/``mode`` combination.
            RuntimeError: if the model never returns a usable response.
        '''
        if method == "baseline" and mode == "atomic":
            # stage-1 pseudocode plan, stage-2 conversion to a JSON program
            representations = self.representer.represent(protocol, mode=mode)
            return self._plan_via_pseudocode(protocol, self._atomic_plan_prompt(protocol, representations))

        if method == "baseline" and mode == "flatten":
            # stage-1 natural-language protocol from full procedures of similar protocols
            representations = self.representer.represent(protocol, mode=mode)
            nl_plan_prompt = (
                self.original_prompt_1.replace("{title}", protocol.title)
                .replace("{details}", protocol.description)
                .replace("{steps}", representations)
            )
            nl_plan = self.__smart_fetch(content=nl_plan_prompt, type="nl")
            # stage-2 pseudocode plan, stage-3 conversion to a JSON program
            pseudocode_plan_prompt = (
                self.original_prompt_2.replace("{title}", protocol.title)
                .replace("{protocol}", nl_plan)
            )
            return self._plan_via_pseudocode(protocol, pseudocode_plan_prompt)

        if method == "internal" and mode == "atomic":
            representations = self.representer.represent(protocol, mode="atomic-internal")
            return self._plan_via_pseudocode(protocol, self._atomic_plan_prompt(protocol, representations))

        if method == "internal" and mode == "dsl":
            oper_repr = self.representer.represent(protocol, mode=mode)
            plan = self.__smart_fetch(content=self._dsl_plan_prompt(protocol, oper_repr), type="json", model="gpt-4o-mini")
            return self._to_protocol(protocol, pseudocode=plan, plan=plan)

        if method == "internal" and mode == "multi-dsl":
            # represent() returns (operation DSL, production DSL) for this mode
            oper_repr, prod_repr = self.representer.represent(protocol, mode=mode)
            plan = self.__smart_fetch(content=self._multi_dsl_plan_prompt(protocol, oper_repr, prod_repr), type="json", model="gpt-4o-mini")
            return self._to_protocol(protocol, pseudocode=plan, plan=plan)

        if method == "external" and mode == "dsl":
            oper_repr = self.representer.represent(protocol, mode=mode)
            # stage-1 reuse a saved "internal" plan if present, otherwise generate one
            plan = self._cached_internal_plan(protocol, "dsl_internal")
            if plan is None:
                plan = self.__smart_fetch(content=self._dsl_plan_prompt(protocol, oper_repr), type="json")
            # stage-2 feedback-refine loop
            best_plan = self._refine_with_compiler(protocol, plan, compiler_mode="dsl", refine_template=self.dsl_refine_prompt)
            return self._to_protocol(protocol, pseudocode=best_plan, plan=best_plan)

        if method == "external" and mode == "multi-dsl":
            oper_repr, prod_repr = self.representer.represent(protocol, mode=mode)
            # stage-1 reuse a saved "internal" plan if present, otherwise generate one
            plan = self._cached_internal_plan(protocol, "multi-dsl_internal")
            if plan is None:
                plan = self.__smart_fetch(content=self._multi_dsl_plan_prompt(protocol, oper_repr, prod_repr), type="json")
            # stage-2 feedback-refine loop
            best_plan = self._refine_with_compiler(protocol, plan, compiler_mode="multi-dsl", refine_template=self.multi_dsl_refine_prompt)
            return self._to_protocol(protocol, pseudocode=best_plan, plan=best_plan)

        if method == "heuristic":
            raise NotImplementedError("The 'heuristic' planning method is not implemented.")
        raise ValueError(f"Unsupported planning combination: method={method!r}, mode={mode!r}")

    # ------------------------------------------------------------------ helpers

    def _atomic_plan_prompt(self, protocol: Protocol, representations) -> str:
        """Fill the atomic-representation template with the protocol and its pseudofunctions."""
        # NOTE(review): "{psuedofunctions}" (sic) is kept as spelled; presumably it
        # matches the placeholder in atomic_representation.txt — verify before fixing.
        return (
            self.atomic_prompt.replace("{title}", protocol.title)
            .replace("{details}", protocol.description)
            .replace("{psuedofunctions}", representations)
        )

    def _dsl_plan_prompt(self, protocol: Protocol, oper_repr) -> str:
        """Fill the single-DSL planning template with the protocol and operation DSL."""
        return (
            self.dsl_prompt.replace("{title}", protocol.title)
            .replace("{details}", protocol.description)
            .replace("{DSL}", oper_repr)
        )

    def _multi_dsl_plan_prompt(self, protocol: Protocol, oper_repr, prod_repr) -> str:
        """Fill the multi-DSL planning template with the protocol, operation and production DSL."""
        return (
            self.multi_dsl_prompt.replace("{title}", protocol.title)
            .replace("{details}", protocol.description)
            .replace("{Operation-DSL}", oper_repr)
            .replace("{Production-DSL}", prod_repr)
        )

    def _plan_via_pseudocode(self, protocol: Protocol, plan_prompt: str) -> Protocol:
        """Fetch a Python pseudocode plan for ``plan_prompt``, then convert it to a JSON program."""
        pseudocode = self.__smart_fetch(content=plan_prompt, type="python")
        convert_prompt = self.pseudocode_to_json_prompt.replace("{pseudocode}", pseudocode)
        plan = self.__smart_fetch(content=convert_prompt, type="json")
        return self._to_protocol(protocol, pseudocode=pseudocode, plan=plan)

    def _cached_internal_plan(self, protocol: Protocol, subdir: str):
        """Return the saved plan under ``planning_result/<domain>/<task>/<subdir>/`` as
        indented JSON text, or ``None`` if no file exists for ``protocol.id``."""
        path = f"planning_result/{self.domain}/{self.task}/{subdir}/{protocol.id}.json"
        if not os.path.exists(path):
            return None
        return json.dumps(read_json(path)["program"], indent=4, ensure_ascii=False)

    def _refine_with_compiler(self, protocol: Protocol, plan: str, compiler_mode: str, refine_template: str, rounds: int = 3, accept_below: int = 3) -> str:
        """Compile → feedback → LLM-refine loop; return the plan text with the least feedback.

        Each round compiles ``plan`` and keeps it as the best candidate when
        ``len(feedback)`` is no larger than any seen so far (ties favour the later
        plan). After at least one refinement, the loop stops early once
        ``len(feedback) < accept_below``. Otherwise the LLM is asked to refine
        ``plan`` with ``refine_template``; no refinement is requested on the final
        round because its result would never be compiled or scored.
        """
        best_plan = plan
        min_feedback_len = float("inf")
        for round_idx in range(rounds):
            feedback = Compiler(program=json.loads(plan), mode=compiler_mode).compile()
            print(len(feedback))
            if len(feedback) <= min_feedback_len:
                min_feedback_len = len(feedback)
                best_plan = plan
            if len(feedback) < accept_below and round_idx > 0:
                break
            if round_idx == rounds - 1:
                break
            refine_prompt = (
                refine_template.replace("{title}", protocol.title)
                .replace("{details}", protocol.description)
                .replace("{plan}", plan)
                .replace("{feedback}", str(feedback))
            )
            plan = self.__smart_fetch(content=refine_prompt, type="json")
        return best_plan

    @staticmethod
    def _to_protocol(source: Protocol, pseudocode: str, plan: str) -> Protocol:
        """Build the result: metadata copied from ``source``, ``program`` parsed from JSON text ``plan``."""
        return Protocol(
            id=source.id,
            title=source.title,
            description=source.description,
            pseudocode=pseudocode,
            program=json.loads(plan),
        )

    def __smart_fetch(self, content, model="gpt-4o-mini", type="nl"):
        """Query the model until the response has the requested shape.

        Args:
            content: Prompt text sent as a single user message.
            model: Chat model name.
            type: "nl" returns the stripped raw text; "python" returns the body of the
                first ```python fenced block; "json" returns the body of the first
                ```json fenced block, provided ``is_json`` accepts it.

        Raises:
            ValueError: if ``type`` is not one of the above (checked before any API call).
            RuntimeError: if 5 attempts produce no usable response.
        """
        if type not in ("nl", "python", "json"):
            raise ValueError(f"Unknown response type: {type!r}")
        response = None
        for _ in range(5):
            response = self.__chatgpt_function(content=content, gpt_model=model)
            if not response:
                # Message content may be None/empty; retry instead of crashing in re.findall.
                continue
            if type == "nl":
                return response.strip()
            if type == "python":
                blocks = re.findall(r'```python([^`]*)```', response, re.DOTALL)
                if blocks:
                    return blocks[0].strip()
            else:
                blocks = re.findall(r'```json([^`]*)```', response, re.DOTALL)
                if blocks and is_json(candidate := blocks[0].strip()):
                    return candidate
        raise RuntimeError(f"Failed to fetch a valid response after 5 attempts. Last response was: {response}")

    @token_count_decorator(flow="together", batch=False)
    def __chatgpt_function(self, content, gpt_model="gpt-4o-mini"):
        """Send ``content`` as one user message to ``gpt_model`` and return the reply text.

        Transient ``openai.APIError``s are printed and retried with exponential
        back-off, up to ``_MAX_API_ATTEMPTS`` attempts. Errors in
        ``_NON_RETRYABLE_ERRORS``, and the last transient error, are re-raised.
        The returned content may be ``None``.
        """
        client = OpenAI(
            api_key=os.environ.get("OPENAI_API_KEY"),
        )
        for attempt in range(self._MAX_API_ATTEMPTS):
            try:
                chat_completion = client.chat.completions.create(
                    messages=[
                        {"role": "user", "content": content}
                    ],
                    model=gpt_model
                )
                return chat_completion.choices[0].message.content
            except self._NON_RETRYABLE_ERRORS:
                raise
            except openai.APIError as error:
                print(error)
                if attempt == self._MAX_API_ATTEMPTS - 1:
                    raise
                time.sleep(min(2 ** attempt, self._MAX_BACKOFF_SECONDS))
